import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os



def _glorot_init(input_dim, output_dim):
    """Build a trainable Glorot/Xavier-uniform weight matrix.

    Entries are sampled uniformly from ``[-r, r)`` where
    ``r = sqrt(6 / (input_dim + output_dim))``.

    Args:
        input_dim: Number of rows of the weight matrix.
        output_dim: Number of columns of the weight matrix.

    Returns:
        An ``nn.Parameter`` of shape ``(input_dim, output_dim)``.
    """
    fan_sum = input_dim + output_dim
    bound = np.sqrt(6.0 / fan_sum)
    # torch.rand samples from [0, 1); stretch to width 2*bound, then shift
    # so the interval is centred on zero.
    unit_samples = torch.rand(input_dim, output_dim)
    return nn.Parameter(unit_samples * 2 * bound - bound)



class GraphConvSparse(nn.Module):
    """Single graph-convolution layer.

    Computes ``activation(adj @ (inputs @ weight))`` with a learnable
    ``(input_dim, output_dim)`` weight and no bias term.

    Args:
        input_dim: Feature width of the incoming node matrix.
        output_dim: Feature width produced by this layer.
        activation: Callable applied to the aggregated features
            (defaults to ``F.relu``).
        **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, input_dim, output_dim, activation=F.relu, **kwargs):
        super(GraphConvSparse, self).__init__(**kwargs)
        self.activation = activation
        self.weight = _glorot_init(input_dim, output_dim)

    def forward(self, inputs, adj):
        """Apply the layer.

        Args:
            inputs: 2-D node-feature matrix with ``input_dim`` columns.
            adj: 2-D matrix multiplied on the left of the projected
                features. NOTE(review): presumably a (possibly sparse,
                normalised) adjacency matrix -- confirm with the caller.

        Returns:
            ``activation(adj @ (inputs @ weight))``.
        """
        # Project the features first, then mix them across nodes via ``adj``.
        projected = torch.mm(inputs, self.weight)
        aggregated = torch.mm(adj, projected)
        return self.activation(aggregated)

class GAE(nn.Module):
    """Graph auto-encoder: two graph-convolution layers plus a linear head.

    The encoder stacks a ReLU ``GraphConvSparse`` layer
    (``input_dim -> hidden1_dim``) and a linear-activation one
    (``hidden1_dim -> hidden2_dim``). The head is an ``nn.Linear`` mapping
    ``2 * hidden2_dim`` features to ``edge_type_num + 2`` raw scores.

    Args:
        input_dim: Width of the input node-feature matrix.
        hidden1_dim: Width of the first hidden representation.
        hidden2_dim: Width of the node embedding produced by ``encode``.
        edge_type_num: The head emits ``edge_type_num + 2`` scores per row.
            NOTE(review): the meaning of the two extra outputs is not
            visible in this file -- confirm with the training code.
    """

    @staticmethod
    def _identity(x):
        """Return ``x`` unchanged (linear activation for the embedding layer).

        This is a named function rather than a ``lambda`` so that it can be
        pickled by reference; a lambda stored on a sub-layer makes the whole
        model unpicklable (e.g. ``torch.save(model)`` or ``copy.deepcopy``).
        """
        return x

    def __init__(self, input_dim, hidden1_dim, hidden2_dim, edge_type_num):
        super(GAE, self).__init__()
        self.base_gcn = GraphConvSparse(input_dim, hidden1_dim)
        self.gcn_mean = GraphConvSparse(
            hidden1_dim, hidden2_dim, activation=self._identity)
        # The head consumes two concatenated embeddings per row, hence the
        # ``hidden2_dim * 2`` input width.
        self.output_fc = nn.Linear(hidden2_dim * 2, edge_type_num + 2)

    def encode(self, X, adj):
        """Return node embeddings ``Z`` computed from features and ``adj``.

        Args:
            X: Node-feature matrix passed to the first layer.
            adj: Matrix passed as ``adj`` to both graph-convolution layers.

        Returns:
            Output of the second layer applied to the first layer's output.
        """
        hidden = self.base_gcn(X, adj)
        return self.gcn_mean(hidden, adj)

    def forward(self, X, adj, index_mat):
        """Score the rows selected by ``index_mat``.

        Args:
            X: Node-feature matrix (see ``encode``).
            adj: Matrix used by the graph-convolution layers.
            index_mat: Index used to gather rows of the embedding matrix.
                Each of its ``len(index_mat)`` entries must gather
                ``2 * hidden2_dim`` values in total to match the head.
                NOTE(review): presumably an ``(E, 2)`` integer tensor of
                node-index pairs -- confirm with the caller.

        Returns:
            Tensor with ``len(index_mat)`` rows and ``edge_type_num + 2``
            columns of raw, un-normalised scores (no softmax is applied).
        """
        Z = self.encode(X, adj)
        # Gather the embeddings for each entry and flatten them into one
        # feature vector per row before the linear head.
        gathered = Z[index_mat].view(len(index_mat), -1)
        return self.output_fc(gathered)





















































